from torch import nn


# CNN model
class CNN(nn.Module):
    """1-D convolutional classifier: conv/pool backbone followed by an MLP head.

    The first ``Conv1d`` has a single input channel, so ``forward`` expects a
    tensor of shape ``(batch, 1, length)``. The backbone ends with 16 channels,
    and the flattened size fed to the first ``Linear`` layer must equal
    ``16 * backbone_output_length``.

    Args:
        y_dimension: Number of output logits (classes).
        model: Dataset identifier used to look up the flattened backbone size.
            Must be a key of ``_NUM_INPUT`` unless ``num_input`` is given.
        num_input: Optional explicit flattened backbone size. It overrides the
            lookup by ``model``, which allows inputs of other lengths.

    Raises:
        ValueError: If ``num_input`` is not given and ``model`` is unknown.
    """

    # Flattened backbone output size (16 channels * output length) per dataset.
    # By the Conv1d(k=2)/MaxPool1d(2, 2) arithmetic:
    #   448 = 16 * 28 corresponds to an input length of 119-122,
    #   432 = 16 * 27 corresponds to an input length of 115-118.
    # NOTE: "precent" (sic) is the exact key callers pass; do not "fix" the
    # spelling here without updating every caller.
    _NUM_INPUT = {
        "kddcup99_all.zst": 448,
        "kddcup_10precent.zst": 432,
    }

    def __init__(self, y_dimension, model: str, *, num_input=None):
        super().__init__()
        if num_input is None:
            try:
                num_input = self._NUM_INPUT[model]
            except KeyError:
                # Previously an unknown name left num_input unbound and crashed
                # later with an opaque UnboundLocalError.
                raise ValueError(
                    f"unknown model {model!r}; expected one of "
                    f"{sorted(self._NUM_INPUT)} or pass num_input explicitly"
                ) from None
        # Layer order/indices are kept as-is so saved state_dicts stay loadable.
        self.backbone = nn.Sequential(
            nn.Conv1d(1, 3, kernel_size=2),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(3, 8, kernel_size=2),
            nn.MaxPool1d(2, 2),
            nn.Conv1d(8, 16, kernel_size=2),
        )
        self.flatten = nn.Flatten()
        self.fc = nn.Sequential(
            nn.Linear(num_input, 64),
            nn.ReLU(),
            nn.Linear(64, 64),
            nn.ReLU(),
            nn.Linear(64, y_dimension),
        )

    def forward(self, X):
        """Return raw logits of shape ``(batch, y_dimension)`` for ``X``.

        ``X`` must be ``(batch, 1, length)``, with ``length`` consistent with
        the ``num_input`` chosen at construction.
        """
        features = self.backbone(X)
        flat = self.flatten(features)
        return self.fc(flat)
